# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import torch
from torch import fx as fx
from torch import nn

# This import automatically registers `torch.ops.silly.attention`
import tests.compile.silly_attention  # noqa
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.inductor_pass import (
    InductorPass,
    get_pass_context,
)
from vllm.config import (
    VllmConfig,
    set_current_vllm_config,
)
from vllm.config.compilation import CompilationConfig, CompilationMode
from vllm.config.scheduler import SchedulerConfig
from vllm.config.utils import Range
from vllm.forward_context import set_forward_context

# Batch size of the first (warm-up) call that run_model makes before running
# the per-test batch sizes.
BATCH_SIZE = 64
# Second (feature) dimension of every random input tensor built by run_model.
MLP_SIZE = 128


@support_torch_compile
class TestModel(nn.Module):
    """Tiny parameter-free model compiled through ``support_torch_compile``.

    The forward pass wraps the custom ``torch.ops.silly.attention`` op, which
    is registered as a side effect of importing ``tests.compile.silly_attention``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
        # This body ignores its arguments and registers no submodules or
        # parameters; the support_torch_compile decorator presumably consumes
        # vllm_config itself (TODO confirm).
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Double ``x``, run silly attention on it, and scale the result by 3.

        ``run_model`` calls this with ``(batch, MLP_SIZE)`` tensors.
        """
        x = x + x
        # Output buffer for the custom op; the op is expected to fill it in
        # place (out-parameter style), since it is used as the result below.
        attn_output = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, attn_output)
        x = attn_output
        x = x * 3
        return x


@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]):
    """Feed ``model`` random inputs inside a vLLM forward context.

    The model is always called first with a ``BATCH_SIZE`` batch, then once
    per entry of ``batch_sizes`` in order. Each input is a
    ``(num_tokens, MLP_SIZE)`` tensor from ``torch.randn``, allocated on the
    current default device. Runs under ``torch.inference_mode``.
    """
    with set_forward_context({}, vllm_config=vllm_config):
        # Warm-up size first, then the caller-provided sizes.
        for num_tokens in (BATCH_SIZE, *batch_sizes):
            model(torch.randn(num_tokens, MLP_SIZE))


class PostGradRangeChecker(InductorPass):
    """Post-grad custom pass that validates the active compile range.

    Every invocation asserts that ``get_pass_context().compile_range`` is one
    of the ranges given at construction time and counts the invocation in
    ``num_calls``. The graph itself is never modified.
    """

    def __init__(self, ranges: list[Range]):
        # Ranges the pass is allowed to observe.
        self.ranges = ranges
        # Number of times the pass has run; tests reset this between phases.
        self.num_calls = 0

    def __call__(self, graph: fx.Graph):
        active_range = get_pass_context().compile_range
        assert active_range in self.ranges, (
            f"Compile range {active_range} not in {self.ranges}"
        )
        self.num_calls = self.num_calls + 1

    def uuid(self) -> str:
        # The hash is taken over an empty dict, so it does not depend on
        # `ranges` or `num_calls`. Every instance yields the same value, which
        # is presumably what Inductor folds into its cache key.
        return InductorPass.hash_dict({})


def test_compile_ranges(use_fresh_inductor_cache):
    """Each compile range and each exercised compile size compiles exactly once.

    With ``compile_ranges_split_points=[8, 32]`` and
    ``max_num_batched_tokens=8192`` the ranges are [1, 8], [9, 32] and
    [33, 8192]. The compile sizes 16 and 64 are expected to get their own
    single-size ranges. The post-grad checker asserts that every compilation
    sees one of these ranges in its pass context.
    """
    post_grad_range_checker = PostGradRangeChecker(
        [
            Range(start=1, end=8),
            Range(start=16, end=16),
            Range(start=9, end=32),
            Range(start=64, end=64),
            Range(start=33, end=8192),
        ]
    )
    # Scope the CUDA default device to this test. A bare
    # torch.set_default_device("cuda") would stay in effect for every test
    # that runs afterwards in the same process.
    with torch.device("cuda"):
        vllm_config = VllmConfig(
            scheduler_config=SchedulerConfig(
                max_num_batched_tokens=8192,
            ),
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                compile_ranges_split_points=[8, 32],
                compile_sizes=[16, 64, 128],
                inductor_compile_config={
                    "post_grad_custom_post_pass": post_grad_range_checker,
                },
            ),
        )

        with set_current_vllm_config(vllm_config):
            model = TestModel(vllm_config=vllm_config, prefix="").eval()
            # run_model first runs BATCH_SIZE (64), then these sizes:
            #   1, 4      -> range [1, 8]
            #   24        -> range [9, 32]
            #   48, 8192  -> range [33, 8192]
            #   16, 64    -> their dedicated compile sizes
            # Compile size 128 is never exercised, so the expected number of
            # backend compilations is 3 ranges + 2 compile sizes = 5.
            batch_sizes = [1, 4, 16, 24, 48, 64, 8192]

            with compilation_counter.expect(
                num_graphs_seen=1,
                num_piecewise_graphs_seen=1,
                num_backend_compilations=5,
            ):
                run_model(vllm_config, model, batch_sizes)
            assert post_grad_range_checker.num_calls == 5


def test_compile_config_get_compile_ranges():
    """Split points [8, 32] with 8192 max batched tokens yield three ranges."""
    expected_ranges = [
        Range(start=1, end=8),
        Range(start=9, end=32),
        Range(start=33, end=8192),
    ]
    compilation_config = CompilationConfig(compile_ranges_split_points=[8, 32])
    scheduler_config = SchedulerConfig(max_num_batched_tokens=8192)

    # The VllmConfig is built only for its side effect on compilation_config.
    # Its post-init presumably finalizes the ranges, using
    # max_num_batched_tokens as the upper bound of the last range.
    VllmConfig(
        scheduler_config=scheduler_config,
        compilation_config=compilation_config,
    )

    assert compilation_config.get_compile_ranges() == expected_ranges


def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
    """A second, identically configured model reuses cached range compilations.

    The first model compiles both ranges [1, 8] and [9, 8192], so the
    post-grad pass runs twice. The second model hits both ranges again but is
    expected to be served from cache, so the pass must not run at all.
    """
    # Disable vLLM's own compile cache so that the second model is compiled
    # again instead of being loaded from vLLM's cached artifacts. The
    # post-grad pass can then only be skipped if Inductor's graph cache hits.
    # That cache is presumably started empty for this test by the
    # use_fresh_inductor_cache fixture.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    post_grad_range_checker = PostGradRangeChecker(
        ranges=[
            Range(start=1, end=8),
            Range(start=9, end=8192),
        ]
    )

    def create_vllm_config():
        # Build a fresh SchedulerConfig each time so the two VllmConfig
        # instances never share (and possibly mutate) one sub-config object.
        return VllmConfig(
            scheduler_config=SchedulerConfig(
                max_num_batched_tokens=8192,
            ),
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                compile_ranges_split_points=[8],
                inductor_compile_config={
                    "post_grad_custom_post_pass": post_grad_range_checker,
                },
            ),
        )

    # Scope the CUDA default device to this test instead of leaking it
    # process-wide via torch.set_default_device.
    with torch.device("cuda"):
        vllm_config_1 = create_vllm_config()
        with set_current_vllm_config(vllm_config_1):
            model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval()
            # Together with the BATCH_SIZE (64) warm-up:
            # 1 -> [1, 8]; 16 and 64 -> [9, 8192].
            batch_sizes = [1, 16]
            run_model(vllm_config_1, model1, batch_sizes)
            assert post_grad_range_checker.num_calls == 2

        post_grad_range_checker.num_calls = 0
        # Create a new vllm config (and hence a new pass context) with the
        # same settings as the first one.
        vllm_config_2 = create_vllm_config()
        with set_current_vllm_config(vllm_config_2):
            model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval()
            # 4 -> [1, 8]; 32 and 64 -> [9, 8192]: both ranges are hit again.
            batch_sizes = [4, 32]
            run_model(vllm_config_2, model2, batch_sizes)
            # Both ranges should be served from the cache, so the post-grad
            # pass must not have run.
            assert post_grad_range_checker.num_calls == 0
